{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Spurious-Motif Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from BA3_loc import *\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import os.path as osp\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "global_b = '0.333' # Set bias degree here\n",
    "data_dir = f'../data/SPMotif-{global_b}/raw/'\n",
    "os.makedirs(data_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_house(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\" Synthetic Graph:\n",
    "\n",
    "    Start with a tree and attach HOUSE-shaped subgraphs.\n",
    "    \"\"\"\n",
    "    list_shapes = [[\"house\"]] * nb_shapes # house\n",
    "\n",
    "    if draw:\n",
    "        plt.figure(figsize=figsize)\n",
    "\n",
    "    G, role_id, _ = synthetic_structsim.build_graph(\n",
    "        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True\n",
    "    )\n",
    "    G = perturb([G], 0.00, id=role_id)[0]\n",
    "\n",
    "    if feature_generator is None:\n",
    "        feature_generator = featgen.ConstFeatureGen(1)\n",
    "    feature_generator.gen_node_features(G)\n",
    "\n",
    "    name = basis_type + \"_\" + str(width_basis) + \"_\" + str(nb_shapes)\n",
    "\n",
    "    return G, role_id, name\n",
    "\n",
    "def get_cycle(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\" Synthetic Graph:\n",
    "\n",
    "    Start with a tree and attach cycle-shaped (directed edges) subgraphs.\n",
    "    \"\"\"\n",
    "    list_shapes = [[\"dircycle\"]] * nb_shapes\n",
    "\n",
    "    if draw:\n",
    "        plt.figure(figsize=figsize)\n",
    "\n",
    "    G, role_id, _ = synthetic_structsim.build_graph(\n",
    "        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True\n",
    "    )\n",
    "    G = perturb([G], 0.00, id=role_id)[0]       # 0.05 original\n",
    "\n",
    "    if feature_generator is None:\n",
    "        feature_generator = featgen.ConstFeatureGen(1)\n",
    "    feature_generator.gen_node_features(G)\n",
    "\n",
    "    name = basis_type + \"_\" + str(width_basis) + \"_\" + str(nb_shapes)\n",
    "\n",
    "    return G, role_id, name\n",
    "\n",
    "def get_crane(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\" Synthetic Graph:\n",
    "\n",
    "    Start with a tree and attach crane-shaped subgraphs.\n",
    "    \"\"\"\n",
    "    list_shapes = [[\"varcycle\"]] * nb_shapes   # crane\n",
    "\n",
    "    if draw:\n",
    "        plt.figure(figsize=figsize)\n",
    "\n",
    "    G, role_id, _ = synthetic_structsim.build_graph(\n",
    "        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True\n",
    "    )\n",
    "    G = perturb([G], 0.00, id=role_id)[0]\n",
    "\n",
    "    if feature_generator is None:\n",
    "        feature_generator = featgen.ConstFeatureGen(1)\n",
    "    feature_generator.gen_node_features(G)\n",
    "\n",
    "    name = basis_type + \"_\" + str(width_basis) + \"_\" + str(nb_shapes)\n",
    "\n",
    "    return G, role_id, name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:03<00:00, 254.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 18.85    #Edges: 27.05 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:03<00:00, 253.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 18.68    #Edges: 27.97 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:03<00:00, 256.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 3000    #Nodes: 18.82    #Edges: 29.13 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list = [], [], []\n",
    "bias = float(global_b)\n",
    "\n",
    "def graph_stats(base_num):\n",
    "    if base_num == 1:\n",
    "        base = 'tree'\n",
    "        width_basis=np.random.choice(range(3))\n",
    "    if base_num == 2:\n",
    "        base = 'ladder'\n",
    "        width_basis=np.random.choice(range(8,12))\n",
    "    if base_num == 3:\n",
    "        base = 'wheel'\n",
    "        width_basis=np.random.choice(range(15,20))\n",
    "    return base, width_basis\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3], p=[bias,(1-bias)/2,(1-bias)/2])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "\n",
    "    G, role_id, name = get_cycle(basis_type=base, nb_shapes=1, \n",
    "                                    width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(0)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int32).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3], p=[(1-bias)/2,bias,(1-bias)/2])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "\n",
    "    G, role_id, name = get_house(basis_type=base, nb_shapes=1, \n",
    "                                    width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(1)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int32).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3], p=[(1-bias)/2,(1-bias)/2,bias])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "    \n",
    "    G, role_id, name = get_crane(basis_type=base, nb_shapes=1, \n",
    "                                    width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(2)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int32).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "with open(osp.join(data_dir, 'train.npy'), 'wb') as f:\n",
    "    pickle.dump((edge_index_list, label_list, ground_truth_list, role_id_list, pos_list), f, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Val Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:03<00:00, 259.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 18.11    #Edges: 26.18 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:03<00:00, 258.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 18.52    #Edges: 27.75 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:03<00:00, 255.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# Graphs: 3000    # Nodes: 19.01    # Edges: 29.48 \n"
     ]
    }
   ],
   "source": [
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list = [], [], []\n",
    "bias = float(global_b)\n",
    "\n",
    "def graph_stats(base_num):\n",
    "    if base_num == 1:\n",
    "        base = 'tree'\n",
    "        width_basis=np.random.choice(range(3))\n",
    "    if base_num == 2:\n",
    "        base = 'ladder'\n",
    "        width_basis=np.random.choice(range(8,12))\n",
    "    if base_num == 3:\n",
    "        base = 'wheel'\n",
    "        width_basis=np.random.choice(range(15,20))\n",
    "    return base, width_basis\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "\n",
    "    G, role_id, name = get_cycle(basis_type=base, nb_shapes=1, \n",
    "                                    width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(0)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int32).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "\n",
    "    G, role_id, name = get_house(basis_type=base, nb_shapes=1, \n",
    "                                    width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(1)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int32).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "    \n",
    "    G, role_id, name = get_crane(basis_type=base, nb_shapes=1, \n",
    "                                    width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(2)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int32).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"# Graphs: %d    # Nodes: %.2f    # Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "with open(osp.join(data_dir, 'val.npy'), 'wb') as f:\n",
    "    pickle.dump((edge_index_list, label_list, ground_truth_list, role_id_list, pos_list), f, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [00:48<00:00, 41.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 89.75    #Edges: 125.26 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [00:47<00:00, 42.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 4000    #Nodes: 89.99    #Edges: 126.62 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [00:46<00:00, 42.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 6000    #Nodes: 89.54    #Edges: 127.04 \n"
     ]
    }
   ],
   "source": [
    "# no bias for test dataset\n",
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list = [], [], []\n",
    "\n",
    "def graph_stats_large(base_num):\n",
    "    if base_num == 1:\n",
    "        base = 'tree'\n",
    "        width_basis=np.random.choice(range(3,6))\n",
    "    if base_num == 2:\n",
    "        base = 'ladder'\n",
    "        width_basis=np.random.choice(range(30,50))\n",
    "    if base_num == 3:\n",
    "        base = 'wheel'\n",
    "        width_basis=np.random.choice(range(60,80))\n",
    "    return base, width_basis\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(2000)):\n",
    "    base_num = np.random.choice([1,2,3]) # uniform\n",
    "    base, width_basis = graph_stats_large(base_num)\n",
    "\n",
    "    G, role_id, name = get_cycle(basis_type=base, nb_shapes=1, \n",
    "                                    width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(0)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int32).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(2000)):\n",
    "    base_num = np.random.choice([1,2,3])\n",
    "    base, width_basis = graph_stats_large(base_num)\n",
    "\n",
    "    G, role_id, name = get_house(basis_type=base, nb_shapes=1, \n",
    "                                    width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(1)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int32).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(2000)):\n",
    "    base_num = np.random.choice([1,2,3])\n",
    "    base, width_basis = graph_stats_large(base_num)\n",
    "\n",
    "    G, role_id, name = get_crane(basis_type=base, nb_shapes=1, \n",
    "                                    width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(2)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int32).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "with open(osp.join(data_dir, 'test.npy'), 'wb') as f:\n",
    "    pickle.dump((edge_index_list, label_list, ground_truth_list, role_id_list, pos_list), f, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('base': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  },
  "vscode": {
   "interpreter": {
    "hash": "c71b0b87ea436ae79e2503ec051639fc2420e91bd742cb356b7debceb9d5ed19"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
